{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tell a Joke instruction dataset, based on https://huggingface.co/datasets/SocialGrep/one-million-reddit-jokes and augmented with keywords extracted by KeyBERT."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/data/datasets/tell_a_joke/tell_a_joke.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic (not !pip) so the packages are installed into the environment\n",
    "# of the running kernel rather than whichever pip the shell happens to resolve.\n",
    "%pip install keybert datasets pandas huggingface_hub swifter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from datasets import load_dataset\n",
    "import pandas as pd\n",
    "from keybert import KeyBERT\n",
    "from huggingface_hub import notebook_login\n",
    "from datasets import Dataset\n",
    "import swifter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: cap on the number of jokes kept, score threshold, and source dataset id.\n",
    "SIZE_LIMIT = 20000\n",
    "MIN_JOKE_SCORE = 100\n",
    "HF_DATASET = \"SocialGrep/one-million-reddit-jokes\"\n",
    "pd.set_option(\"display.max_colwidth\", None)\n",
    "\n",
    "# Load the source dataset and the keyword extractor used later by make_item.\n",
    "jokes_ds = load_dataset(HF_DATASET)\n",
    "kw_model = KeyBERT()\n",
    "pd_jokes = jokes_ds[\"train\"].to_pandas()\n",
    "pd_jokes = pd_jokes.reset_index(drop=True)\n",
    "\n",
    "# Keep only posts from subreddits not flagged NSFW.\n",
    "filtered_jokes = pd_jokes[pd_jokes[\"subreddit.nsfw\"] == False]\n",
    "# Reassign instead of dropna(inplace=True): mutating a boolean-indexed slice in place\n",
    "# can raise pandas' SettingWithCopyWarning and makes the cell harder to re-run safely.\n",
    "filtered_jokes = filtered_jokes.dropna(subset=[\"selftext\"])\n",
    "# Score must be strictly greater than the threshold.\n",
    "filtered_jokes = filtered_jokes[filtered_jokes[\"score\"] > MIN_JOKE_SCORE]\n",
    "# Drop posts whose body is a deletion/removal placeholder.\n",
    "filtered_jokes = filtered_jokes[filtered_jokes[\"selftext\"] != \"[deleted]\"]\n",
    "filtered_jokes = filtered_jokes[filtered_jokes[\"selftext\"] != \"[removed]\"]\n",
    "# Drop posts whose body contains \"edit:\" (case-insensitive).\n",
    "filtered_jokes = filtered_jokes[~filtered_jokes[\"selftext\"].str.contains(\"edit:\", case=False)]\n",
    "\n",
    "# Highest-scored jokes first, then keep at most SIZE_LIMIT rows.\n",
    "filtered_jokes = filtered_jokes.sort_values(\"score\", ascending=False)\n",
    "filtered_jokes = filtered_jokes[0:SIZE_LIMIT]\n",
    "print(len(filtered_jokes))\n",
    "\n",
    "# Preview the relevant columns; as the last expression of the cell it is actually displayed.\n",
    "filtered_jokes[[\"score\", \"title\", \"selftext\", \"subreddit.nsfw\"]].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Request templates; the placeholder is filled with the keyword extracted from the joke.\n",
    "joke_requests = [\n",
    "    \"Can you share a joke that involves {}?\",\n",
    "    \"Do you know any jokes related to {}?\",\n",
    "    \"Could you tell me a funny joke that has to do with {}?\",\n",
    "    \"I'm in the mood for a joke about {}. Do you have any good ones?\",\n",
    "    \"Would you happen to have a joke about {} that you could tell me?\",\n",
    "    \"Can you think of a joke that centers around {}?\",\n",
    "    \"I'd love to hear a witty joke related to {}. Do you have one?\",\n",
    "    \"Tell me a humorous joke that involves {}.\",\n",
    "    \"Could you please entertain me with a joke related to {}?\",\n",
    "    \"What's a good joke that relates to {}?\",\n",
    "    \"I could use a good laugh. How about a joke about {}?\",\n",
    "    \"What's a funny joke that relates to {}?\",\n",
    "    \"Can you make me chuckle with a joke that involves {}?\",\n",
    "    \"I'm curious if you have a joke up your sleeve that pertains to {}?\",\n",
    "    \"Do you have a favorite joke that involves {}?\",\n",
    "    \"Mind sharing a joke with me that has to do with {}?\",\n",
    "    \"How about a joke related to {}? Do you have one?\",\n",
    "    \"I'm in need of a good joke. Something that centers around {} should do the trick.\",\n",
    "    \"Would you be willing to share a joke about {} with me?\",\n",
    "    \"Can you think of a joke that relates to {} that you could tell me?\",\n",
    "]\n",
    "\n",
    "\n",
    "def make_item(joke):\n",
    "    \"\"\"Turn one joke row into an INSTRUCTION/RESPONSE/SOURCE/METADATA record.\n",
    "\n",
    "    The response is the post title and body joined by a newline. The instruction is a\n",
    "    randomly chosen request template filled with the first keyword returned by KeyBERT\n",
    "    (n-gram range 1-2, English stop words). If anything in that step raises, the error\n",
    "    is printed together with the joke and a generic request is used instead.\n",
    "\n",
    "    Args:\n",
    "        joke: row-like object providing \"title\", \"selftext\" and \"permalink\".\n",
    "\n",
    "    Returns:\n",
    "        pd.Series indexed by INSTRUCTION, RESPONSE, SOURCE and METADATA.\n",
    "    \"\"\"\n",
    "    title, body, permalink = joke[\"title\"], joke[\"selftext\"], joke[\"permalink\"]\n",
    "    joke_text = f\"{title}\\n{body}\"\n",
    "    template = random.choice(joke_requests)\n",
    "\n",
    "    try:\n",
    "        ranked = kw_model.extract_keywords(\n",
    "            joke_text,\n",
    "            keyphrase_ngram_range=(1, 2),\n",
    "            stop_words=\"english\",\n",
    "        )\n",
    "        # An empty result raises IndexError here and falls through to the generic request.\n",
    "        instruction = template.format(ranked[0][0])\n",
    "    except Exception as err:\n",
    "        print(\"Error:\", err, joke_text, joke)\n",
    "        instruction = \"Could you tell me a random joke?\"\n",
    "\n",
    "    columns = [\"INSTRUCTION\", \"RESPONSE\", \"SOURCE\", \"METADATA\"]\n",
    "    values = [instruction, joke_text, HF_DATASET, {\"nsfw\": False, \"link\": permalink}]\n",
    "    return pd.Series(values, index=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every filtered joke into a record via make_item; %time reports the duration.\n",
    "%time oa_format = filtered_jokes.swifter.apply(make_item, axis=1)\n",
    "print(oa_format.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the records to a local Parquet file with pyarrow, without the index column.\n",
    "oa_format.to_parquet(\n",
    "    \"dataset.parquet\",\n",
    "    engine=\"pyarrow\",\n",
    "    index=False,\n",
    "    row_group_size=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the Parquet file as a Hugging Face Dataset and upload it to the Hub.\n",
    "# NOTE(review): notebook_login is imported in this notebook but never called;\n",
    "# push_to_hub presumably needs an authenticated session - confirm before running.\n",
    "hub_repo_id = f\"mikegarts/oa_tell_a_joke_{SIZE_LIMIT}\"\n",
    "ds = Dataset.from_parquet(\"dataset.parquet\")\n",
    "ds.push_to_hub(hub_repo_id)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
